# -*- coding: utf-8 -*- 
"""
========================================================================================================================
@project : my-sanic
@file: BaseModelContoller
@Author: mengying
@email: 652044581@qq.com
@date: 2023/3/16 23:05
@desc: 数据库操作的基类
========================================================================================================================
"""

from typing import Dict, List, Union
from abc import ABCMeta, abstractmethod
from database import commons
from addict import Dict as asDict
from utils.myTimeFormat import MyTime
from utils.myLog import SeqLog
from bson import ObjectId

# Module-level logger; MongoError.log_record writes its debug messages through it.
seqLog = SeqLog().get_logger()


class MongoError(Exception):
    """Base class for this module's MongoDB errors.

    Subclasses are meant to override ``error_handler`` (custom handling) and
    ``log_record`` (logging). Both hooks run from ``__init__``, i.e. as soon as
    the exception object is constructed, not when it is raised.

    Derives from ``Exception`` rather than ``BaseException`` so that ordinary
    ``except Exception`` handlers catch it. The message is forwarded to the
    base constructor so ``str(err)`` and ``err.args`` carry it.
    """

    def __init__(self, message):
        super().__init__(message)
        self.message = message
        self.error_handler()
        self.log_record()

    def error_handler(self):
        """Hook for subclass-specific error handling; a no-op by default."""
        pass

    def log_record(self):
        """Log ``self.message`` at DEBUG level through the module logger."""
        seqLog.debug("MongoError记录错误消息: %s" % str(self.message))


class Transform:
    """Convert (or drop) the ``_id`` field of query results so they are JSON-friendly."""

    @staticmethod
    def _convert_doc(doc, keeper_id):
        """Stringify or remove the ``_id`` of a single document, in place."""
        # Membership test instead of truthiness, so falsy ids such as 0 or ""
        # are still handled.
        if "_id" not in doc:
            return
        if not keeper_id:
            doc.pop("_id")
        elif doc["_id"] is not None:
            doc["_id"] = str(doc["_id"])

    @staticmethod
    def convert_obj_string(data, keeper_id=True):
        """Convert ``_id`` values of one document or a list of documents.

        @params: data  a document (dict-like) or a list of documents; empty or
                 ``None`` input is returned unchanged
        @params: keeper_id  True -> keep ``_id`` but convert it to ``str``
                 (``None`` ids are left as ``None``); False -> remove ``_id``
        @return: the same object that was passed in, modified in place
        """
        if not data:
            return data

        documents = data if isinstance(data, list) else [data]
        for document in documents:
            Transform._convert_doc(document, keeper_id)
        return data


class BaseMongoMethod(metaclass=ABCMeta):
    """Shared plumbing for the sync and async MongoDB helper classes.

    Subclasses set ``_model_class`` and implement ``template``.
    """

    # Model class whose ``collection_name`` (or class name) selects the collection.
    _model_class = None

    @property
    def _db(self):
        """Database handle exposed as ``database.commons.db``."""
        return commons.db

    @abstractmethod
    def template(self):
        """Return the default document; inserted data is merged on top of a copy of it."""
        return {}

    @property
    def collection_name(self) -> str:
        """Name of the collection backing ``_model_class``.

        Resolution order:
        1. ``_model_class.collection_name``. It is called if callable (for example
           a staticmethod or classmethod); otherwise it is used as-is (a plain
           string attribute).
        2. ``_model_class.__name__`` when no truthy ``collection_name`` is defined.

        @raises: Exception if ``_model_class`` is not set
        """
        if not self._model_class:
            raise Exception("请定义你的model_class")

        custom = getattr(self._model_class, "collection_name", None)
        if not custom:
            return self._model_class.__name__
        return custom() if callable(custom) else custom


class AsyncMongoMethod(BaseMongoMethod):
    """Asynchronous CRUD helpers for the collection named by ``collection_name``.

    All I/O is awaited on ``self._db[self.collection_name]``. Documents returned
    to callers have their ``_id`` handled by ``Transform.convert_obj_string``.
    """

    def _backup_collection_name(self, node: str) -> str:
        """Backup collection name for ``node``: ``<collection>_<node>_backups``."""
        return "_".join([self.collection_name, node, "backups"])

    @staticmethod
    def _sort_spec(sort_list: List) -> List:
        """Turn ``["a", "-b"]`` into ``[("a", 1), ("b", -1)]``; a ``-`` prefix means descending."""
        spec = []
        for key in sort_list:
            if key.startswith("-"):
                spec.append((key[1:], -1))
            else:
                spec.append((key, 1))
        return spec

    @staticmethod
    def _compose_update(update: Dict, update_time_auto: bool) -> Dict:
        """Build the MongoDB update document without mutating ``update``.

        - Plain fields ``{"name": "x"}`` become ``{"$set": {"name": "x"}}``.
        - An operator document such as ``{"$push": {...}, "$inc": {...}}`` is
          used as-is, and every operator is kept, not just the first one.
        - With ``update_time_auto`` the timestamp always goes into ``$set``, so it
          is never pushed or incremented by another operator.
        """
        if not update:
            raise ValueError("update must be a non-empty dict")

        if any(str(key).startswith("$") for key in update):
            document = dict(update)
        else:
            document = {"$set": dict(update)}

        if update_time_auto:
            set_fields = dict(document.get("$set") or {})
            set_fields["update_time"] = MyTime.TimeFormat(MyTime.dateTimeType)
            document["$set"] = set_fields
        return document

    async def insert_one(self, data: Dict, validation: Dict = None, create_time_auto: bool = False) -> str:
        """Insert one document built from ``template()`` plus ``data``.

        @params: data  document fields; they override the template defaults
        @params: validation  optional condition; if a document already matches
                 it, the insert is refused
        @params: create_time_auto  set ``create_time`` to the current time
        @return: the new document's ``_id`` as a string
        @raises: Exception when ``validation`` matches an existing document
        """
        # Start from the subclass template; caller data wins on conflicts.
        model_template = self.template().copy()
        model_template.update(data)

        if validation:
            existing = await self.find_one(validation)
            if existing:
                raise Exception("数据:%s插入, 校验%s失败" % (self.collection_name, str(validation)))

        if create_time_auto:
            model_template.update({"create_time": MyTime.TimeFormat(MyTime.dateTimeType)})

        result = await self._db[self.collection_name].insert_one(model_template)
        return str(result.inserted_id)

    async def find_one(self, condition: Dict, field: List = None, keeper_id: bool = True) -> Dict:
        """Return the first document matching ``condition``, or ``None``.

        @params: condition  query filter
        @params: field  optional list of field names to project
        @params: keeper_id  keep ``_id`` (as a string) when True, drop it when False
        """
        collection = self._db[self.collection_name]
        if field:
            data = await collection.find_one(condition, {key: 1 for key in field})
        else:
            data = await collection.find_one(condition)

        return Transform.convert_obj_string(data, keeper_id=keeper_id)

    async def insert_many(self, collection_list: List[Dict], create_time_auto: bool = False) -> List:
        """Insert several documents, each built from ``template()`` plus one item.

        @params: collection_list  list of field dicts; each overrides the template
        @params: create_time_auto  pre-fill ``create_time`` (an item's own value wins)
        @return: inserted ``_id``s as strings; ``[]`` for an empty input list
        """
        documents = []
        for item in collection_list:
            document = self.template().copy()
            if create_time_auto:
                document["create_time"] = MyTime.TimeFormat(MyTime.dateTimeType)
            # dict.update() returns None: update first, then append the dict itself.
            document.update(item)
            documents.append(document)

        # The driver rejects an empty document list, so short-circuit instead.
        if not documents:
            return []

        result = await self._db[self.collection_name].insert_many(documents)
        return [str(_id) for _id in result.inserted_ids]

    async def find(self, condition: Dict, field: List = None, page: int = 1,
                   size: int = 0, sort_list: List = None, total_count: bool = False,
                   keeper_id: bool = True) -> Union[Dict, List]:
        """Query all documents matching ``condition``.

        @params: condition  query filter
        @params: field  optional list of field names to project (default: all fields)
        @params: page  1-based page number (only used when ``size > 0``)
        @params: size  page size; 0 (default) disables pagination
        @params: sort_list  e.g. ``["a", "-b"]`` (ascending a, then descending b)
        @params: total_count  when True, return ``{"data": [...], "total": n}``
                 where ``total`` counts all matches before pagination
        @params: keeper_id  keep ``_id`` (as a string) when True, drop it when False
        """
        collection = self._db[self.collection_name]
        if field:
            cursor = collection.find(condition, {key: 1 for key in field})
        else:
            cursor = collection.find(condition)

        if sort_list:
            cursor = cursor.sort(self._sort_spec(sort_list))

        if size > 0:
            cursor = cursor.skip((page - 1) * size).limit(size)

        data_list = [document async for document in cursor]
        data_list = Transform.convert_obj_string(data_list, keeper_id=keeper_id)

        if total_count:
            my_dict = asDict()
            my_dict.data = data_list
            my_dict.total = await self.count(condition)
            return my_dict.to_dict()

        return data_list

    async def count(self, condition: Dict) -> int:
        """Return the number of documents matching ``condition``."""
        return await self._db[self.collection_name].count_documents(condition)

    async def update(self, condition: Dict, update: Dict, upsert: bool = False, update_time_auto: bool = False) -> int:
        """Update every document matching ``condition``.

        @params: condition  query filter
        @params: update  plain fields (``{"name": "xiaoming"}``, applied with ``$set``)
                 or an operator document (``{"$push": {"name": "xiaoming"}}``)
        @params: upsert  insert a new document when nothing matches
        @params: update_time_auto  also ``$set`` ``update_time`` to the current time
        @return: the driver's ``matched_count``
        @raises: ValueError when ``update`` is empty
        """
        document = self._compose_update(update, update_time_auto)
        result = await self._db[self.collection_name].update_many(condition, document, upsert=upsert)
        return result.matched_count

    async def delete(self, condition: Dict) -> int:
        """Delete every document matching ``condition``; return how many were deleted."""
        result = await self._db[self.collection_name].delete_many(condition)
        return result.deleted_count

    async def aggregate(self, pipeline: List[Dict]):
        """Run an aggregation pipeline (e.g. ``$lookup`` joins) and return all results as a list.

        @params: pipeline  list of aggregation stages
        """
        import inspect

        cursor = self._db[self.collection_name].aggregate(pipeline)
        # Depending on the async driver, aggregate() returns the cursor directly
        # (Motor) or an awaitable that resolves to it; support both.
        if inspect.isawaitable(cursor):
            cursor = await cursor
        # The cursor is async-iterable only; list(cursor) would raise TypeError.
        return [document async for document in cursor]

    async def backups(self, condition: Dict, node: str = "default") -> None:
        """Copy the first document matching ``condition`` into a backup collection.

        @params: condition  query selecting the document to back up
        @params: node  backup slot; the target collection is ``<collection>_<node>_backups``
        @return: the driver's insert result, or ``None`` when nothing matches
        """
        # Read the raw document (not self.find_one) so ``_id`` keeps its original
        # type; converting through str()/ObjectId() broke non-ObjectId ids.
        source = await self._db[self.collection_name].find_one(condition)
        if not source:
            return None

        backup_collection = self._db[self._backup_collection_name(node)]
        # Replace any earlier backup of the same document (delete, then insert;
        # the two steps are not atomic).
        await backup_collection.delete_many({"_id": source["_id"]})
        return await backup_collection.insert_one(source)

    async def restore(self, condition: Dict, node: str = "default") -> None:
        """Restore the first backed-up document matching ``condition``.

        @params: condition  query against the backup collection
        @params: node  backup slot; the source collection is ``<collection>_<node>_backups``
        @return: the driver's insert result, or ``None`` when no backup matches
        """
        restore_data = await self._db[self._backup_collection_name(node)].find_one(condition)
        if not restore_data:
            return None

        collection = self._db[self.collection_name]
        # Replace the live document with the same ``_id`` (delete, then insert).
        await collection.delete_many({"_id": restore_data["_id"]})
        return await collection.insert_one(restore_data)


class SyncMongoMethod(BaseMongoMethod):
    """Blocking CRUD helpers for the collection named by ``collection_name``.

    All I/O goes through ``self._db[self.collection_name]``. Documents returned
    to callers have their ``_id`` handled by ``Transform.convert_obj_string``.
    """

    def _backup_collection_name(self, node: str) -> str:
        """Backup collection name for ``node``: ``<collection>_<node>_backups``."""
        return "_".join([self.collection_name, node, "backups"])

    @staticmethod
    def _sort_spec(sort_list: List) -> List:
        """Turn ``["a", "-b"]`` into ``[("a", 1), ("b", -1)]``; a ``-`` prefix means descending."""
        spec = []
        for key in sort_list:
            if key.startswith("-"):
                spec.append((key[1:], -1))
            else:
                spec.append((key, 1))
        return spec

    @staticmethod
    def _compose_update(update: Dict, update_time_auto: bool) -> Dict:
        """Build the MongoDB update document without mutating ``update``.

        - Plain fields ``{"name": "x"}`` become ``{"$set": {"name": "x"}}``.
        - An operator document such as ``{"$push": {...}, "$inc": {...}}`` is
          used as-is, and every operator is kept, not just the first one.
        - With ``update_time_auto`` the timestamp always goes into ``$set``, so it
          is never pushed or incremented by another operator.
        """
        if not update:
            raise ValueError("update must be a non-empty dict")

        if any(str(key).startswith("$") for key in update):
            document = dict(update)
        else:
            document = {"$set": dict(update)}

        if update_time_auto:
            set_fields = dict(document.get("$set") or {})
            set_fields["update_time"] = MyTime.TimeFormat(MyTime.dateTimeType)
            document["$set"] = set_fields
        return document

    def insert_one(self, data: Dict, validation: Dict = None, create_time_auto: bool = False) -> str:
        """Insert one document built from ``template()`` plus ``data``.

        @params: data  document fields; they override the template defaults
        @params: validation  optional condition; if a document already matches
                 it, the insert is refused
        @params: create_time_auto  set ``create_time`` to the current time
        @return: the new document's ``_id`` as a string
        @raises: Exception when ``validation`` matches an existing document
        """
        # Start from the subclass template; caller data wins on conflicts.
        model_template = self.template().copy()
        model_template.update(data)

        if validation:
            existing = self.find_one(validation)
            if existing:
                raise Exception("数据:%s插入, 校验%s失败" % (self.collection_name, str(validation)))

        if create_time_auto:
            model_template.update({"create_time": MyTime.TimeFormat(MyTime.dateTimeType)})

        result = self._db[self.collection_name].insert_one(model_template)
        return str(result.inserted_id)

    def find_one(self, condition: Dict, field: List = None, keeper_id: bool = True) -> Dict:
        """Return the first document matching ``condition``, or ``None``.

        @params: condition  query filter
        @params: field  optional list of field names to project
        @params: keeper_id  keep ``_id`` (as a string) when True, drop it when False
        """
        collection = self._db[self.collection_name]
        if field:
            data = collection.find_one(condition, {key: 1 for key in field})
        else:
            data = collection.find_one(condition)

        return Transform.convert_obj_string(data, keeper_id=keeper_id)

    def insert_many(self, collection_list: List[Dict], create_time_auto: bool = False) -> List:
        """Insert several documents, each built from ``template()`` plus one item.

        @params: collection_list  list of field dicts; each overrides the template
        @params: create_time_auto  pre-fill ``create_time`` (an item's own value wins)
        @return: inserted ``_id``s as strings; ``[]`` for an empty input list
        """
        documents = []
        for item in collection_list:
            document = self.template().copy()
            if create_time_auto:
                document["create_time"] = MyTime.TimeFormat(MyTime.dateTimeType)
            # dict.update() returns None: update first, then append the dict itself.
            document.update(item)
            documents.append(document)

        # The driver rejects an empty document list, so short-circuit instead.
        if not documents:
            return []

        result = self._db[self.collection_name].insert_many(documents)
        return [str(_id) for _id in result.inserted_ids]

    def find(self, condition: Dict, field: List = None, page: int = 1,
             size: int = 0, sort_list: List = None, total_count: bool = False,
             keeper_id: bool = True) -> Union[Dict, List]:
        """Query all documents matching ``condition``.

        @params: condition  query filter
        @params: field  optional list of field names to project (default: all fields)
        @params: page  1-based page number (only used when ``size > 0``)
        @params: size  page size; 0 (default) disables pagination
        @params: sort_list  e.g. ``["a", "-b"]`` (ascending a, then descending b)
        @params: total_count  when True, return ``{"data": [...], "total": n}``
                 where ``total`` counts all matches before pagination
        @params: keeper_id  keep ``_id`` (as a string) when True, drop it when False
        """
        collection = self._db[self.collection_name]
        if field:
            cursor = collection.find(condition, {key: 1 for key in field})
        else:
            cursor = collection.find(condition)

        if sort_list:
            cursor = cursor.sort(self._sort_spec(sort_list))

        if size > 0:
            cursor = cursor.skip((page - 1) * size).limit(size)

        data_list = list(cursor)
        data_list = Transform.convert_obj_string(data_list, keeper_id=keeper_id)

        if total_count:
            my_dict = asDict()
            my_dict.data = data_list
            my_dict.total = self.count(condition)
            return my_dict.to_dict()

        return data_list

    def count(self, condition: Dict) -> int:
        """Return the number of documents matching ``condition``."""
        return self._db[self.collection_name].count_documents(condition)

    def update(self, condition: Dict, update: Dict, upsert: bool = False, update_time_auto: bool = False) -> int:
        """Update every document matching ``condition``.

        @params: condition  query filter
        @params: update  plain fields (``{"name": "xiaoming"}``, applied with ``$set``)
                 or an operator document (``{"$push": {"name": "xiaoming"}}``)
        @params: upsert  insert a new document when nothing matches
        @params: update_time_auto  also ``$set`` ``update_time`` to the current time
        @return: the driver's ``matched_count``
        @raises: ValueError when ``update`` is empty
        """
        document = self._compose_update(update, update_time_auto)
        result = self._db[self.collection_name].update_many(condition, document, upsert=upsert)
        return result.matched_count

    def delete(self, condition: Dict) -> int:
        """Delete every document matching ``condition``; return how many were deleted."""
        result = self._db[self.collection_name].delete_many(condition)
        return result.deleted_count

    def aggregate(self, pipeline: List[Dict]):
        """Run an aggregation pipeline (e.g. ``$lookup`` joins) and return all results as a list.

        @params: pipeline  list of aggregation stages
        """
        return list(self._db[self.collection_name].aggregate(pipeline))

    def backups(self, condition: Dict, node: str = "default") -> None:
        """Copy the first document matching ``condition`` into a backup collection.

        @params: condition  query selecting the document to back up
        @params: node  backup slot; the target collection is ``<collection>_<node>_backups``
        @return: the driver's insert result, or ``None`` when nothing matches
        """
        # Read the raw document (not self.find_one) so ``_id`` keeps its original
        # type; converting through str()/ObjectId() broke non-ObjectId ids.
        source = self._db[self.collection_name].find_one(condition)
        if not source:
            return None

        backup_collection = self._db[self._backup_collection_name(node)]
        # Replace any earlier backup of the same document (delete, then insert;
        # the two steps are not atomic).
        backup_collection.delete_many({"_id": source["_id"]})
        return backup_collection.insert_one(source)

    def restore(self, condition: Dict, node: str = "default") -> None:
        """Restore the first backed-up document matching ``condition``.

        @params: condition  query against the backup collection
        @params: node  backup slot; the source collection is ``<collection>_<node>_backups``
        @return: the driver's insert result, or ``None`` when no backup matches
        """
        restore_data = self._db[self._backup_collection_name(node)].find_one(condition)
        if not restore_data:
            return None

        collection = self._db[self.collection_name]
        # Replace the live document with the same ``_id`` (delete, then insert).
        collection.delete_many({"_id": restore_data["_id"]})
        return collection.insert_one(restore_data)
